package ppbot;

import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;


/**
 * Fully connected feed-forward neural network with sigmoid activations that
 * also acts as an individual of a genetic algorithm.
 *
 * <p>The neurons are kept in one flat list, layer by layer: the input neurons
 * come first, the output neurons last. Every non-input neuron is connected to
 * all neurons of the preceding layer. The genome of the individual consists of
 * the weights and biases of the non-input neurons.
 */
class NeuralNetwork implements GeneticIndividual {

    /** Shared random source for mutation and for choosing the crossover neuron. */
    private static final Random random = new Random();

    private double lambda = 2;          ///< the sigmoid sharpness (factor applied to the weighted sum)

    public final int inputCount;        ///< input neuron count (size of the first layer)
    public final int outputCount;       ///< output neuron count (size of the last layer)
    public final int neuronCount;       ///< total neuron count over all layers

    public final int[] layers;          ///< nn structure as layer sizes, i.e. 3-5-3

    private List<Neuron> neurons;       ///< neuron structure, stored layer by layer

    /* genetic parameters of the nn */
    public double geneticFitness;       ///< fitness of the genetic individual
    public boolean elite;               ///< flag for elitism; elite networks are not mutated
    public double mutStrength = 1;      ///< mutation disturbance (scale of the Gaussian noise)

    /**
     * A single neuron. A neuron created with zero inputs is an input neuron:
     * its {@link #output} is set from outside and {@link #compute()} leaves it
     * untouched.
     *
     * <p>This is a non-static inner class, so every neuron is bound to the
     * network it was created through and reads that network's sigmoid
     * sharpness.
     */
    public class Neuron implements Cloneable, Serializable {

        /** Last computed activation, or the externally set value for an input neuron. */
        public double output;

        /** Neurons of the preceding layer feeding this neuron. */
        private Neuron[] inputNeurons;

        /** One weight per entry of {@link #inputNeurons}. */
        private double[] weights;

        /** Constant term added to the weighted sum. */
        private double bias;

        /** Number of incoming connections; zero marks an input neuron. */
        private int inputCount;

        /**
         * @return {@code true} when this neuron has no incoming connections
         */
        public boolean isInputNeuron()
        {
            return inputCount == 0;
        }

        /**
         * Creates a neuron with zeroed weights, zero bias and no wired inputs.
         *
         * @param inputCount number of incoming connections (0 for an input neuron)
         */
        public Neuron(int inputCount)
        {
            this.inputCount = inputCount;
            weights = new double[inputCount];
            inputNeurons = new Neuron[inputCount];
            bias = 0;
        }

        /**
         * Sets the weight of one incoming connection.
         *
         * @param index  connection index
         * @param weight new weight
         */
        public void setWeight(int index, double weight)
        {
            weights[index] = weight;
        }

        /**
         * Wires one incoming connection.
         *
         * @param index  connection index
         * @param neuron source neuron whose output is read by {@link #compute()}
         */
        public void setInputNeuron(int index, Neuron neuron)
        {
            inputNeurons[index] = neuron;
        }

        /**
         * Logistic function of {@code lambda * x}.
         *
         * @param x weighted input sum
         * @return value in the range [0, 1]
         */
        private double sigmoide(double x)
        {
            double scaled = lambda * x;
            // Saturate explicitly for extreme arguments.
            if (scaled < -1e+10) {
                return 0;
            }
            if (scaled > +1e+10) {
                return 1;
            }
            return 1 / (1 + Math.exp(-scaled));
        }

        /**
         * Recomputes the activation from the current outputs of the input
         * neurons and stores it in {@link #output}. Input neurons simply
         * return their externally set output.
         *
         * @return the (new) output of this neuron
         */
        public double compute()
        {
            if (inputCount == 0) {
                return output;
            }

            double sum = getBias();
            for (int i = 0; i < inputCount; i++) {
                sum += weights[i] * inputNeurons[i].output;
            }
            output = sigmoide(sum);
            return output;
        }

        /** @return the bias of this neuron */
        public double getBias() {
            return bias;
        }

        /** @param bias the new bias of this neuron */
        public void setBias(double bias) {
            this.bias = bias;
        }

        /**
         * Draws a new bias uniformly from [-2, 2) and every weight uniformly
         * from [-1, 1).
         */
        public void initWeightsRandomly() {
            setBias(4 * Math.random() - 2);
            for (int i = 0; i < weights.length; i++) {
                weights[i] = 2 * Math.random() - 1;
            }
        }

        /**
         * Copies this neuron. The weight array and the array of input
         * references are duplicated, the referenced input neurons are not.
         * Because this is an inner class, the copy stays bound to the same
         * enclosing network as this neuron.
         *
         * @return the copy
         */
        @Override
        public Neuron clone()
        {
            try {
                Neuron copy = (Neuron) super.clone();
                copy.weights = weights.clone();
                copy.inputNeurons = inputNeurons.clone();
                return copy;
            } catch (CloneNotSupportedException ex) {
                // Cannot happen: Neuron implements Cloneable. Fail loudly
                // instead of handing a null neuron to the caller.
                throw new AssertionError(ex);
            }
        }

        /** @return bias and weights formatted with two decimals */
        @Override
        public String toString()
        {
            StringBuilder str = new StringBuilder();
            str.append(String.format("b = %.2f, w = [", bias));
            for (double w : weights) {
                str.append(String.format("%.2f ", w));
            }
            str.append("]");
            return str.toString();
        }

        /**
         * Exchanges genes (weights and bias) with another neuron.
         *
         * @param neuron      the other neuron
         * @param isThisElite when {@code true} this neuron is left unchanged
         *                    and the other one receives a copy of its genes;
         *                    otherwise the genes of both neurons are swapped
         */
        public void crossover(Neuron neuron, boolean isThisElite)
        {
            if (isThisElite) {
                neuron.weights = weights.clone();
                neuron.bias = bias;
                return;
            }

            double[] otherWeights = neuron.weights;
            neuron.weights = weights;
            weights = otherWeights;

            double otherBias = neuron.bias;
            neuron.bias = bias;
            bias = otherBias;
        }

    }

    /**
     * Connects every non-input neuron of {@link #neurons} to all neurons of
     * the preceding layer.
     *
     * @param layers layer sizes describing the layout of {@link #neurons}
     */
    private void initNetworkStructure(int[] layers)
    {
        int target = layers[0];         // index of the next neuron to wire
        int previousLayerStart = 0;     // index of the first neuron of the preceding layer
        for (int layer = 1; layer < layers.length; layer++) {
            int previousLayerSize = layers[layer - 1];
            for (int j = 0; j < layers[layer]; j++) {
                Neuron neuron = neurons.get(target++);
                for (int k = 0; k < previousLayerSize; k++) {
                    neuron.setInputNeuron(k, neurons.get(previousLayerStart + k));
                }
            }
            previousLayerStart += previousLayerSize;
        }
    }

    /**
     * Builds a network with the given layer sizes. All weights and biases
     * start at zero; call {@link #initWeightsRandomly()} to randomize them.
     *
     * @param layers layer sizes from the input layer to the output layer; the
     *               array is copied, so later changes by the caller do not
     *               affect the network
     * @throws NullPointerException     if {@code layers} is {@code null}
     * @throws IllegalArgumentException if {@code layers} is empty or contains
     *                                  a negative size
     */
    public NeuralNetwork(int[] layers)
    {
        if (layers == null) {
            throw new NullPointerException("layers");
        }
        if (layers.length == 0) {
            throw new IllegalArgumentException("layers must contain at least one layer");
        }
        int total = 0;
        for (int size : layers) {
            if (size < 0) {
                throw new IllegalArgumentException("layer sizes must not be negative: " + size);
            }
            total += size;
        }

        this.layers = layers.clone();
        neuronCount = total;
        inputCount = this.layers[0];
        outputCount = this.layers[this.layers.length - 1];
        neurons = new ArrayList<Neuron>(neuronCount);

        for (int i = 0; i < inputCount; i++) {
            neurons.add(new Neuron(0));
        }
        for (int layer = 1; layer < this.layers.length; layer++) {
            for (int j = 0; j < this.layers[layer]; j++) {
                neurons.add(new Neuron(this.layers[layer - 1]));
            }
        }

        initNetworkStructure(this.layers);
    }

    /**
     * Sets the value of one input neuron.
     *
     * @param index index of the input neuron, expected in [0, inputCount)
     * @param input value to feed into the network
     */
    public void setInput(int index, double input)
    {
        neurons.get(index).output = input;
    }

    /**
     * Reads the value of one output neuron as left by the last
     * {@link #compute()} call.
     *
     * @param index index of the output neuron, expected in [0, outputCount)
     * @return output of that neuron
     */
    public double getOutput(int index)
    {
        return neurons.get(neuronCount - outputCount + index).output;
    }

    /**
     * Propagates the current inputs through the network. Neurons are
     * evaluated in storage order, i.e. layer by layer, so every neuron sees
     * the fresh outputs of the preceding layer.
     */
    public void compute()
    {
        for (Neuron neuron : neurons) {
            neuron.compute();
        }
    }

    /** Randomizes the weights and bias of every neuron. */
    public void initWeightsRandomly() {
        for (Neuron neuron : neurons) {
            neuron.initWeightsRandomly();
        }
    }

    /** Randomizes the genome by delegating to {@link #initWeightsRandomly()}. */
    @Override
    public void initializeRandomly() {
        initWeightsRandomly();
        /* ... */
    }

    /** @return the value of {@link #geneticFitness} */
    @Override
    public double getFitness() {
        return geneticFitness;
    }

    /**
     * Adds Gaussian noise scaled by {@link #mutStrength} to the bias and to
     * each weight of every non-input neuron; each gene is disturbed
     * independently with probability {@code mutRate}. Elite networks are
     * left unchanged.
     *
     * @param mutRate per-gene mutation probability
     */
    @Override
    public void mutate(double mutRate) {
        if (elite) {
            return;
        }

        for (Neuron neuron : neurons) {
            if (neuron.isInputNeuron()) {
                continue;
            }
            if (random.nextDouble() < mutRate) {
                neuron.bias += mutStrength * random.nextGaussian();
            }
            for (int i = 0; i < neuron.weights.length; i++) {
                if (random.nextDouble() < mutRate) {
                    neuron.weights[i] += mutStrength * random.nextGaussian();
                }
            }
        }
    }

    /**
     * Creates an independent copy of this network: same structure, same
     * weights, biases and current outputs, same genetic parameters. The
     * copy's neurons are created through the copy itself so that they are
     * bound to the copy rather than to this network.
     *
     * @return the copy
     * @throws IllegalStateException if the runtime class does not support
     *                               {@link Object#clone()}
     */
    @Override
    public NeuralNetwork clone() {
        NeuralNetwork copy;
        try {
            copy = (NeuralNetwork) super.clone();
        } catch (CloneNotSupportedException e) {
            // Surface the real cause instead of failing later with a
            // NullPointerException on a null copy.
            throw new IllegalStateException(
                    "Cannot clone " + getClass().getName() + ": it is not Cloneable", e);
        }

        List<Neuron> copies = new ArrayList<Neuron>(neurons.size());
        for (Neuron source : neurons) {
            // copy.new binds the neuron to the copy; Neuron.clone() would
            // keep the hidden reference to this network.
            Neuron duplicate = copy.new Neuron(source.inputCount);
            duplicate.weights = source.weights.clone();
            duplicate.bias = source.bias;
            duplicate.output = source.output;
            copies.add(duplicate);
        }
        copy.neurons = copies;

        // Wire the fresh neurons to each other instead of to this network's neurons.
        copy.initNetworkStructure(copy.layers);

        return copy;
    }

    /** @return the genes of all non-input neurons enclosed in braces */
    @Override
    public String toString()
    {
        StringBuilder str = new StringBuilder();
        str.append("{");
        for (Neuron neuron : neurons) {
            if (!neuron.isInputNeuron()) {
                str.append(neuron.toString());
            }
        }
        str.append("}");
        return str.toString();
    }

    /* just for debugging of the network */
    public static void main(String[] args) {
    }

    /** @param b new value of the elitism flag */
    @Override
    public void setElite(boolean b) {
        elite = b;
    }

    /** @return the elitism flag */
    @Override
    public boolean isElite() {
        return elite;
    }

    /**
     * Exchanges the genes of one randomly chosen non-input neuron with the
     * neuron at the same position in the partner network. An elite network
     * is never modified: if exactly one side is elite the other side
     * receives a copy of the elite's genes, if both are elite nothing
     * happens, otherwise the genes are swapped.
     *
     * @param indi the partner; must be a {@code NeuralNetwork}
     * @throws ClassCastException       if {@code indi} is not a {@code NeuralNetwork}
     * @throws IllegalArgumentException if the chosen neurons do not have the
     *                                  same number of inputs
     */
    @Override
    public void crossover(GeneticIndividual indi) {
        NeuralNetwork partner = (NeuralNetwork) indi;

        if (elite && partner.elite) {
            return;     // both sides are protected, nothing may change
        }

        int trainable = neuronCount - inputCount;
        if (trainable <= 0) {
            return;     // no neuron carries genes
        }

        // Skip the input neurons: they carry no weights, and without the
        // offset the last inputCount neurons could never be selected.
        int index = inputCount + random.nextInt(trainable);

        Neuron mine = neurons.get(index);
        Neuron theirs = partner.neurons.get(index);
        if (mine.inputCount != theirs.inputCount) {
            throw new IllegalArgumentException(
                    "Cannot cross networks of different structure: neuron " + index
                    + " has " + mine.inputCount + " vs " + theirs.inputCount + " inputs");
        }

        if (partner.elite) {
            theirs.crossover(mine, true);
        } else {
            mine.crossover(theirs, elite);
        }
    }

    /**
     * Writes the weight array followed by the bias of every neuron, in
     * storage order.
     *
     * @param oos destination stream
     * @throws IOException if writing fails
     */
    @Override
    public void serialize(ObjectOutputStream oos) throws IOException {
        for (Neuron neuron : neurons) {
            oos.writeObject(neuron.weights);
            oos.writeDouble(neuron.bias);
        }
    }

    /**
     * Reads weights and biases in the format written by
     * {@link #serialize(ObjectOutputStream)}. The data is validated against
     * the structure of this network and only applied once the whole genome
     * has been read, so a failure leaves the network unchanged.
     *
     * <p>NOTE(review): {@code readObject()} uses Java native deserialization;
     * only feed this method streams from a trusted source.
     *
     * @param ois source stream
     * @throws IOException            if reading fails or the stored data does
     *                                not match the structure of this network
     * @throws ClassNotFoundException if the stream references an unknown class
     */
    @Override
    public void deserialize(ObjectInputStream ois) throws IOException, ClassNotFoundException {
        int count = neurons.size();
        double[][] loadedWeights = new double[count][];
        double[] loadedBiases = new double[count];

        for (int i = 0; i < count; i++) {
            Object stored = ois.readObject();
            if (!(stored instanceof double[])) {
                throw new IOException("Neuron " + i + ": expected a weight array in the stream");
            }
            double[] weights = (double[]) stored;
            int expected = neurons.get(i).inputCount;
            if (weights.length != expected) {
                throw new IOException("Neuron " + i + ": expected " + expected
                        + " weights but the stream holds " + weights.length);
            }
            loadedWeights[i] = weights;
            loadedBiases[i] = ois.readDouble();
        }

        for (int i = 0; i < count; i++) {
            Neuron neuron = neurons.get(i);
            neuron.weights = loadedWeights[i];
            neuron.bias = loadedBiases[i];
        }
    }

}